# ri_functions.R
# functions for randomisation inference (bootstrap)
# July 2023

# function to re-randomise villages within blocks, then individuals
#
# Produces one re-randomised copy of the data for randomisation inference.
#
# Arguments:
#   indata - data frame; the code reads the columns `village_treatment`,
#            `block` and `q123` (q123 is used as the village identifier).
#
# Returns:
#   the rows of `indata` (re-ordered by new village treatment) with added
#   columns `n` (number of villages in the block), `new_village_treat` and
#   `new_individual` (a factor with 'Placebo' as the reference level).
#
# Notes:
#   - village treatments are drawn without replacement from the distinct values
#     of `village_treatment`, so sample() will fail for any block that has more
#     villages than there are distinct treatments.
#   - only the four hard-coded village labels below are handled; rows whose new
#     village treatment is anything else (including NA) are dropped, which
#     triggers the warning at the end.
#   - the order of the sample() calls is kept fixed so that results under a
#     given set.seed() are reproducible.
re_random = function(indata){
  four_treatments = unique(indata$village_treatment) # four possible village-level treatments
  # re-randomise treatments at block-level: one row per block/village pair,
  # then a draw without replacement inside each block
  frame = select(indata, block, q123) %>% 
    unique() %>%
    group_by(block) %>%
    mutate(n = n(),
           new_village_treat = sample(four_treatments, size=n(), replace=FALSE)) %>% # make new village-level treatment
    ungroup()
  # add new village-level treatment back to data
  outdata = left_join(indata, frame, by=c('block','q123')) # merge by block and village
  # with(outdata, table(new_village_treat, village_treatment)) # quick check
  # now sample individual treatments:
  # placebo villages are all placebo; in the other villages each individual
  # gets the village's active treatment with probability 0.75, else placebo
  f1 = filter(outdata, new_village_treat == 'Placebo') %>%
    mutate(new_individual = 'Placebo')
  f2 = filter(outdata, new_village_treat == 'High Cash') %>%
    mutate(new_individual = sample(c('HighCash','Placebo'), size=n(), replace=TRUE, prob=c(0.75,0.25)))
  f3 = filter(outdata, new_village_treat == 'Low Cash') %>%
    mutate(new_individual = sample(c('LowCash','Placebo'), size=n(), replace=TRUE, prob=c(0.75,0.25)))
  f4 = filter(outdata, new_village_treat == 'CDC health') %>%
    mutate(new_individual = sample(c('CDCHealth','Placebo'), size=n(), replace=TRUE, prob=c(0.75,0.25)))
  outdata = bind_rows(f1,f2,f3,f4)
  # sanity check: every input row should have landed in exactly one of the four
  # groups; warn (rather than silently printing) so the problem is visible and
  # traceable, while still returning the data as before
  if(nrow(outdata) != nrow(indata)){
    warning('re_random: row count changed from ', nrow(indata), ' to ', nrow(outdata),
            '; some rows have a village treatment that is not one of ',
            "'Placebo', 'High Cash', 'Low Cash', 'CDC health'",
            call. = FALSE)
  }
  # with(outdata, table(new_individual, individual_treatment_ri)) # quick check
  # make treatment a factor
  outdata = mutate(outdata, new_individual = factor(new_individual),
                   new_individual = relevel(new_individual, ref='Placebo')) # reference category
  # return new data
  return(outdata)
}

## function to run bootstrap
#
# Fits the observed model, then refits it `n_boot` times on re-randomised
# treatment assignments (see re_random) and compares the observed treatment
# estimates with the re-randomised distribution.
#
# Arguments:
#   indata   - data frame passed to the model and to re_random(); `q123` is
#              used as the clustering variable.
#   n_boot   - number of re-randomisations (single number >= 1; fractional
#              values are truncated).
#   equation - model formula as a character string that refers to the observed
#              treatment as `individual_treatment_ri`.
#
# Returns:
#   a list with
#     $plot - ggplot histogram of the re-randomised odds ratios per treatment
#             term, with the observed odds ratio as a red line;
#     $ests - one row per treatment term with the number of draws (n), the
#             number of draws above the observed estimate (r), the bootstrap
#             p-value, and the observed estimate and p-value.
#
# glm.cluster and clean_names are not defined in this file; the code relies on
# summary() of a glm.cluster fit being coercible to a data frame whose cleaned
# column names include `estimate` and `pr_z`.
run_boot = function(indata ,
                    n_boot = 1000,
                    equation){
  
  # validate the number of draws: 1:n_boot would silently iterate over c(1, 0)
  # for n_boot = 0, so reject anything below one draw up front
  if(!is.numeric(n_boot) || length(n_boot) != 1 || is.na(n_boot) || n_boot < 1){
    stop('n_boot must be a single number >= 1', call. = FALSE)
  }
  n_draws = as.integer(n_boot) # truncates fractional values, as 1:n_boot did
  
  # make new equation; replace every occurrence so that equations using the
  # treatment more than once (e.g. in an interaction) are fully re-randomised
  equation_new = str_replace_all(equation, pattern = 'individual_treatment_ri', replacement = 'new_individual')
  
  # run original model to get observed results
  observed = glm.cluster(equation, 
                         cluster = indata$q123,
                         family = 'binomial',
                         data = indata)
  # keep only the treatment rows and strip the variable name so that row names
  # match those of the re-randomised models
  summary_observed = hush(summary(observed)) %>%
    as.data.frame() %>%
    clean_names() %>%
    tibble::rownames_to_column() %>%
    filter(str_detect(rowname, 'individual_')) %>%
    mutate(rowname = str_remove_all(rowname, pattern='individual_treatment_ri')) %>%
    select(rowname, estimate, pr_z) %>%
    rename('observed' = 'estimate',
           'observed_pval' = 'pr_z')
  
  # now bootstrap; results are collected in a pre-allocated list and bound
  # once at the end instead of growing a data frame on every iteration
  boot_list = vector('list', n_draws)
  for (k in seq_len(n_draws)){
    new_data = re_random(indata) # re-randomise
    new_model = glm.cluster(equation_new, 
                            cluster = new_data$q123,
                            family = 'binomial',
                            data = new_data)
    boot_list[[k]] = hush(summary(new_model)) %>%
      as.data.frame() %>%
      mutate(boot = k) %>%
      tibble::rownames_to_column() %>%
      filter(str_detect(rowname, 'new_individual'))
    # progress message
    if(k%%1000==0){cat('Up to bootstrap ',k,'\n', sep='')}
  }
  boot_results = bind_rows(boot_list) %>%
    clean_names() %>%
    mutate(rowname = str_remove_all(rowname, 'new_individual'))
  
  # plot of re-randomised odds ratios against the observed odds ratio
  gplot = ggplot(data = boot_results, aes(x = exp(estimate))) +
    geom_vline(lty=2, xintercept=1)+
    geom_histogram(col='grey77', fill='skyblue')+
    geom_vline(data = summary_observed, aes(xintercept = exp(observed)), col='red')+
    facet_wrap(~rowname)+
    theme_bw()+
    xlab('Odds ratio')+
    ylab('Frequency')
  
  # calculate bootstrap p-values
  # NOTE(review): the lower tail is 1 - (r+1)/(n+1) = (n-r)/(n+1), which is
  # exactly 0 when every draw exceeds the observed estimate; confirm whether
  # (n-r+1)/(n+1) was intended for the lower tail
  ests = left_join(boot_results, summary_observed, by='rowname') %>%
    mutate(diff = as.numeric(estimate - observed)>0) %>%
    group_by(rowname) %>%
    summarise(n = n(),
              r = sum(diff)) %>%
    mutate(pval1 = (r+1)/(n+1), # one-sided
           pval.dash = 1 - pval1,
           pval_bootstrap = 2*pmin(pval1, pval.dash))%>% # two-sided p-value
    ungroup() %>%
    select(-pval1, -pval.dash) %>%
    left_join(summary_observed, by='rowname') # add original results

  # return
  to_return = list()
  to_return$plot = gplot
  to_return$ests = ests
  return(to_return)
  
} # end of function

# used to avoid printing out output
#
# Evaluates `code` with standard output diverted to the null device and
# returns its value. `code` is a lazily evaluated argument, so it is only run
# once the diversion is in place.
#
# Arguments:
#   code - expression to evaluate quietly.
#
# Returns:
#   the value of `code`.
hush=function(code){
  # nullfile() gives the platform's null device ("NUL" on Windows, "/dev/null"
  # elsewhere), so no stray file called NUL is written on UNIX (needs R >= 3.6.0)
  sink(nullfile())
  # always remove the diversion, even if `code` throws an error
  on.exit(sink(), add = TRUE)
  tmp = code
  return(tmp)
}
